package db

import (
	"context"
	"gitee.com/jiangjiali/leaf/conf"
	"gitee.com/jiangjiali/leaf/log"
	"reflect"
	"sync"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readpref"
)

// Session mongo session: wraps a mongo driver client together with
// chainable query state.
//
// Connection settings (uri, maxPoolSize) are consumed by Connect; the query
// fields (filter, limit, skip, sort) are read by One and/or All.
// NOTE(review): query state is stored on the Session itself and no code in
// this part of the file resets it, so successive or concurrent queries on one
// Session share it — confirm this is intended.
type Session struct {
	// client is the driver client; assigned by Connect, nil before that.
	client      *mongo.Client
	// collection is the collection One and All query; it is not assigned in
	// this part of the file (presumably set by a Find-style method elsewhere).
	collection  *mongo.Collection
	// maxPoolSize is passed to the driver's SetMaxPoolSize by Connect.
	maxPoolSize uint64
	// db is the database name used by C; C defaults it to "test" when empty.
	db          string
	// uri is the connection string used by Connect.
	uri         string
	// m is locked by the setters and by C/DB; the query-builder fields below
	// are written by Limit/Skip/Sort without taking it.
	m           sync.RWMutex
	// filter is the query filter passed to FindOne/Find.
	filter      interface{}
	// limit, when non-nil, is applied by All via SetLimit.
	limit       *int64
	// project is not read in the visible code; presumably a projection
	// document — verify against the rest of the package.
	project     interface{}
	// skip, when non-nil, is applied by All via SetSkip.
	skip        *int64
	// sort, when non-nil, is applied by All via SetSort.
	sort        interface{}
}

// New creates a Session and connects it.
//
// The URI comes from conf.MongoAddr, falling back to
// "mongodb://localhost:27017" when that is empty; the pool limit comes from
// conf.MongoPool. A connection error is reported through log.Fatal.
func New() *Session {
	uri := conf.MongoAddr
	if uri == "" {
		uri = "mongodb://localhost:27017"
	}

	sess := &Session{uri: uri}
	sess.SetPoolLimit(conf.MongoPool)

	if err := sess.Connect(); err != nil {
		log.Fatal("数据库连接失败：%v", err)
	}
	return sess
}

// S is the package-level default Session; it stays nil until Init runs.
var S *Session

// Init builds the default Session with New and publishes it in S.
func Init() {
	session := New()
	S = session
}

// Disconnect closes the default Session S.
//
// It is a no-op when Init has not been called (S is nil) or when S has no
// client, instead of panicking on a nil dereference. Driver errors are
// deliberately discarded: this is a best-effort shutdown helper with no
// error return.
func Disconnect() {
	if S == nil || S.client == nil {
		return
	}
	// Best-effort: there is no way to surface the error from this signature.
	_ = S.client.Disconnect(context.TODO())
}

// C returns a handle to the named collection in the session's database.
//
// When no database name has been set, s.db is defaulted to "test" and that
// default is kept on the session. The session mutex is held while s.db and
// s.client are accessed.
func (s *Session) C(collection string) *Collection {
	s.m.Lock()
	defer s.m.Unlock()

	if s.db == "" {
		s.db = "test"
	}
	database := s.client.Database(s.db)
	return &Collection{collection: database.Collection(collection)}
}

// SetUri sets the connection URI that Connect will use. It does not affect a
// client that is already connected.
func (s *Session) SetUri(uri string) {
	s.m.Lock()
	s.uri = uri
	s.m.Unlock()
}

// SetPoolLimit sets the maximum connection-pool size that Connect passes to
// the driver. It does not affect a client that is already connected.
func (s *Session) SetPoolLimit(limit uint64) {
	s.m.Lock()
	s.maxPoolSize = limit
	s.m.Unlock()
}

// Connect creates a driver client from the session's URI and pool limit and
// connects it, bounding the connect step with a 20-second context.
//
// On success the client is stored in the session; on failure the session is
// left unchanged and the driver error is returned. The URI and pool limit are
// read, and the client stored, under the same mutex that SetUri/SetPoolLimit
// and C/DB use, so Connect does not race with them.
//
// NOTE(review): per the driver docs, a pool size of 0 means "no limit", and
// connecting does not by itself prove a server is reachable (use Ping).
// Calling Connect again replaces s.client without disconnecting the old one.
func (s *Session) Connect() error {
	s.m.RLock()
	uri, poolSize := s.uri, s.maxPoolSize
	s.m.RUnlock()

	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()

	// SetMaxPoolSize after ApplyURI so the configured limit wins, as before.
	opt := options.Client().ApplyURI(uri).SetMaxPoolSize(poolSize)

	// mongo.Connect is NewClient + Client.Connect in one call (NewClient is
	// deprecated).
	client, err := mongo.Connect(ctx, opt)
	if err != nil {
		return err
	}

	s.m.Lock()
	s.client = client
	s.m.Unlock()
	return nil
}

// Ping checks that the client can reach a server matching the primary read
// preference, returning the driver's error if it cannot. The read preference
// is always readpref.Primary(), never the client default.
func (s *Session) Ping() error {
	primary := readpref.Primary()
	return s.client.Ping(context.TODO(), primary)
}

// Client returns the underlying driver client. It is nil until Connect has
// succeeded.
func (s *Session) Client() (client *mongo.Client) {
	client = s.client
	return
}

// Disconnect closes the session's driver client.
//
// It is a no-op on a session that was never connected (nil client), where
// the original code would panic. Driver errors are deliberately discarded:
// this is a best-effort shutdown and the method has no error return.
func (s *Session) Disconnect() {
	if s.client == nil {
		return
	}
	// Best-effort: there is no way to surface the error from this signature.
	_ = s.client.Disconnect(context.TODO())
}

// DB returns a handle to the database with the given name, using the
// session's client. The session mutex is held while the client is accessed.
func (s *Session) DB(db string) *Database {
	s.m.Lock()
	defer s.m.Unlock()

	handle := s.client.Database(db)
	return &Database{database: handle}
}

// Limit sets the maximum number of documents All returns. Per the driver's
// FindOptions.SetLimit docs, a negative value requests a single batch. One
// ignores this setting. It returns s so calls can be chained.
func (s *Session) Limit(limit int64) *Session {
	s.limit = new(int64)
	*s.limit = limit
	return s
}

// Skip sets how many matching documents All skips before returning results.
// One ignores this setting. It returns s so calls can be chained.
func (s *Session) Skip(skip int64) *Session {
	s.skip = new(int64)
	*s.skip = skip
	return s
}

// Sort sets the sort specification that All passes to the driver (presumably
// an ordered document such as bson.D{{"field", 1}} — confirm with callers).
// One ignores this setting. It returns s so calls can be chained.
func (s *Session) Sort(sort interface{}) *Session {
	spec := sort
	s.sort = spec
	return s
}

// One decodes at most one document matching the session's filter into result.
//
// Only s.filter is used: the Sort, Limit and Skip settings are not applied.
// Any error from the driver is returned unchanged (per the driver's
// SingleResult docs, mongo.ErrNoDocuments when nothing matches), as is any
// bson.Unmarshal error.
func (s *Session) One(result interface{}) error {
	raw, err := s.collection.FindOne(context.TODO(), s.filter).DecodeBytes()
	if err != nil {
		return err
	}
	return bson.Unmarshal(raw, result)
}

// All runs a Find with the session's filter and decodes every matching
// document into result.
//
// result must be a pointer to a slice, or a pointer to an interface holding a
// slice. The session's Sort, Limit and Skip settings are applied when set.
// The query and the iteration share a 5-second timeout.
//
// On success, the slice behind result is replaced by one holding exactly the
// decoded documents. Any existing backing array is reused, so old elements
// may be overwritten. On error, result is left untouched. All panics if
// result is not a slice address, since that is a caller bug.
func (s *Session) All(result interface{}) error {
	resultValue := reflect.ValueOf(result)
	if resultValue.Kind() != reflect.Ptr {
		panic("result argument must be a slice address")
	}
	sliceValue := resultValue.Elem()
	if sliceValue.Kind() == reflect.Interface {
		sliceValue = sliceValue.Elem()
	}
	if sliceValue.Kind() != reflect.Slice {
		panic("result argument must be a slice address")
	}

	// Start from length 0 so decoded documents land at index 0 onward.
	// The previous code extended the slice to its capacity and then appended,
	// which placed results after stale elements whenever cap > 0.
	sliceValue = sliceValue.Slice(0, 0)
	elemType := sliceValue.Type().Elem()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	opt := options.Find()
	if s.sort != nil {
		opt.SetSort(s.sort)
	}
	if s.limit != nil {
		opt.SetLimit(*s.limit)
	}
	if s.skip != nil {
		opt.SetSkip(*s.skip)
	}

	cur, err := s.collection.Find(ctx, s.filter, opt)
	if err != nil {
		// cur is nil when Find fails; deferring Close before this check
		// would panic.
		return err
	}
	// Close error is ignored: results are already decoded or an earlier
	// error is being returned.
	defer cur.Close(ctx)

	for cur.Next(ctx) {
		elem := reflect.New(elemType)
		if err := bson.Unmarshal(cur.Current, elem.Interface()); err != nil {
			return err
		}
		sliceValue = reflect.Append(sliceValue, elem.Elem())
	}
	// Next returns false both at the end of the results and on failure
	// (for example the timeout). Only Err, checked after the loop, tells
	// them apart.
	if err := cur.Err(); err != nil {
		return err
	}

	resultValue.Elem().Set(sliceValue)
	return nil
}
